#!/usr/bin/env python3

from boututils.run_wrapper import shell, shell_safe, launch_safe
from boutdata.collect import collect
from netCDF4 import Dataset
import numpy
import os.path
from sys import stdout, exit

# Build the test executable; the build output is captured in make.log.
print("Making griddata test")
shell_safe("make > make.log")

# Grid size used by every generated grid file.
nx = 4
ny = 24
# The y-domain is divided into 6 equal blocks; the separatrix indices written
# to the grid files are multiples of this. Floor division keeps the value an
# int (ny/6 is a float under Python 3), so the int32 grid variables derived
# from it are assigned integers rather than relying on an implicit cast.
blocksize = ny // 6

# Generate the grid files under test, starting with the double-null case.
# One file is written per number of y-boundary guard cells stored in the file.
for n_yguards in [0, 1, 2]:
    suffix = str(n_yguards)
    datadir = "data-doublenull-" + suffix
    gridname = "grid-doublenull-" + suffix + ".nc"

    # Total number of y-points stored in the file: ny plus 4*n_yguards extra.
    ny_file = ny + 4*n_yguards

    # Scalar int32 variables, written in this order.
    scalars = [
        ('nx', nx),
        ('ny', ny),
        ('y_boundary_guards', n_yguards),
        ('MXG', 1),
        # when the file has no y-boundary guards, MYG is still set to 2
        ('MYG', 2 if n_yguards==0 else n_yguards),
        ('ixseps1', nx//2 - 1),
        ('ixseps2', nx//2  - 1),
        ('jyseps1_1', blocksize - 1),
        ('jyseps2_1', 2*blocksize - 1),
        ('ny_inner', 3*blocksize),
        ('jyseps1_2', 4*blocksize - 1),
        ('jyseps2_2', 5*blocksize - 1),
    ]

    with Dataset(os.path.join(datadir,gridname), 'w') as gridfile:
        gridfile.createDimension('x', nx)
        gridfile.createDimension('y', ny_file)

        for varname, value in scalars:
            gridfile.createVariable(varname, numpy.int32)
            gridfile[varname][...] = value

        # 'test' holds the file's own y-index at every x, so the checks can
        # tell which grid points ended up where in the output.
        testdata = numpy.tile(numpy.arange(ny_file, dtype=float), (nx, 1))
        gridfile.createVariable('test', float, ('x', 'y'))
        gridfile['test'][...] = testdata

# Grid files for the single-null case, again one per number of y-boundary
# guard cells stored in the file.
for n_yguards in [0, 1, 2]:
    suffix = str(n_yguards)
    datadir = "data-singlenull-" + suffix
    gridname = "grid-singlenull-" + suffix + ".nc"

    # Total number of y-points stored in the file: ny plus 2*n_yguards extra.
    ny_file = ny + 2*n_yguards

    # Scalar int32 variables, written in this order.
    scalars = [
        ('nx', nx),
        ('ny', ny),
        ('y_boundary_guards', n_yguards),
        ('MXG', 1),
        # when the file has no y-boundary guards, MYG is still set to 2
        ('MYG', 2 if n_yguards==0 else n_yguards),
        ('ixseps1', nx//2 - 1),
        ('ixseps2', nx//2  - 1),
        ('jyseps1_1', blocksize - 1),
        ('jyseps2_1', ny//2),
        ('ny_inner', ny//2),
        ('jyseps1_2', ny//2),
        ('jyseps2_2', 5*blocksize - 1),
    ]

    with Dataset(os.path.join(datadir,gridname), 'w') as gridfile:
        gridfile.createDimension('x', nx)
        gridfile.createDimension('y', ny_file)

        for varname, value in scalars:
            gridfile.createVariable(varname, numpy.int32)
            gridfile[varname][...] = value

        # 'test' holds the file's own y-index at every x, so the checks can
        # tell which grid points ended up where in the output.
        testdata = numpy.tile(numpy.arange(ny_file, dtype=float), (nx, 1))
        gridfile.createVariable('test', float, ('x', 'y'))
        gridfile['test'][...] = testdata


# Run the test executable on every generated grid and compare the 'test'
# field it writes out against the values expected from the grid file.
for nproc in [6]:
    stdout.write("Checking %d processors ... " % (nproc))

    # Clear output and logs from earlier runs; the run logs below are opened
    # in append mode, so stale logs would otherwise accumulate.
    shell("rm ./data*/BOUT.dmp.*.nc run.log.*")

    success = True

    # double null tests
    for n_yguards in [0, 1, 2]:
        datadir = "data-doublenull-" + str(n_yguards)

        _, out = launch_safe("./test_griddata -d " + datadir, nproc=nproc, pipe=True)

        with open("run.log.doublenull."+str(nproc), "a") as f:
            f.write(out)

        testfield = collect("test", path=datadir, info=False, yguards=True)

        # Build the expected y-profile. Each piece is offset from the last
        # value appended so far (checkfield[-1]).
        if n_yguards == 0:
            # output has 2 y-guard cells, but grid file did not
            myg = 2
            # lower guard cells are expected to be zero
            checkfield = list(numpy.zeros(myg))
            checkfield += list(numpy.arange(ny//2))
            checkfield += list(numpy.arange(ny//2) + checkfield[-1]  + 1)
            # upper guard cells are expected to repeat the last grid value
            checkfield += list(numpy.zeros(myg) + checkfield[-1])
        else:
            checkfield = []
            checkfield += list(numpy.arange(n_yguards))
            checkfield += list(numpy.arange(ny//2) + checkfield[-1] + 1)
            # 2*n_yguards grid-file values are skipped between the two halves
            # NOTE(review): presumably the guard cells stored in the middle
            # of the double-null grid file, which do not appear in the
            # collected output -- confirm.
            checkfield += list(numpy.arange(ny//2) + checkfield[-1]  + 1 + 2*n_yguards)
            checkfield += list(numpy.arange(n_yguards) + checkfield[-1] + 1)
        checkfield = numpy.array(checkfield)

        # Test value of testfield
        if numpy.max(numpy.abs(testfield - checkfield)) > 1e-13:
            print("Failed: testfield does not match in doublenull case for n_yguards="+str(n_yguards))
            success = False

    # single null tests
    for n_yguards in [0, 1, 2]:
        datadir = "data-singlenull-" + str(n_yguards)

        _, out = launch_safe("./test_griddata -d " + datadir, nproc=nproc, pipe=True)

        with open("run.log.singlenull."+str(nproc), "a") as f:
            f.write(out)

        testfield = collect("test", path=datadir, info=False, yguards=True)

        # Build the expected y-profile, as above but with a single
        # contiguous run of ny grid values.
        if n_yguards == 0:
            # output has 2 y-guard cells, but grid file did not
            myg = 2
            # lower guard cells are expected to be zero
            checkfield = list(numpy.zeros(myg))
            checkfield += list(numpy.arange(ny))
            # upper guard cells are expected to repeat the last grid value
            checkfield += list(numpy.zeros(myg) + checkfield[-1])
        else:
            checkfield = []
            checkfield += list(numpy.arange(n_yguards))
            checkfield += list(numpy.arange(ny) + checkfield[-1] + 1)
            checkfield += list(numpy.arange(n_yguards) + checkfield[-1] + 1)
        checkfield = numpy.array(checkfield)

        # Test value of testfield
        if numpy.max(numpy.abs(testfield - checkfield)) > 1e-13:
            # Report the case actually being tested (single null).
            print("Failed: testfield does not match in singlenull case for n_yguards="+str(n_yguards))
            success = False

    # Exit non-zero as soon as any case failed for this processor count.
    if not success:
        exit(1)
    else:
        print("Passed")

exit(0)
